from utils import Loader
from mydatasets import DatasetForLR
from train import trainLR
from models.LRForMicroVideo import LR

def train():
    """Run one training epoch of the LR model, optionally followed by validation.

    The configuration is read from a fixed YAML path. A training dataset
    (``mode="train"``, ``need_sample=True``) and a ``trainLR`` instance are
    always built from it.

    * When ``config['Train']['offline']`` is truthy, a default-mode ``LR``
      model is created and trained for one epoch; nothing else happens.
    * Otherwise the model is created with ``mode='test'``, a test dataset
      (``mode="test"``, ``need_sample=False``) is built, one training epoch
      is run, and ``valid_one_epoch`` is then called on the test dataset.

    Returns:
        None. The value returned by ``valid_one_epoch`` is not used here.
    """
    config = Loader.load_yaml("/media/Harddisk/goog/compitition-CTR/config/microvideoml.yml")

    train_loader = DatasetForLR(config, mode="train", need_sample=True)
    trainer = trainLR(config)

    offline = config.get('Train').get('offline')
    if offline:
        # Offline run: train only, no pass over the test dataset.
        offline_model = LR(cfg=config)
        trainer.train_one_epoch(offline_model, train_loader)
        return

    # Non-offline run: the test dataset is built before training starts,
    # then the freshly trained model is passed to valid_one_epoch.
    online_model = LR(cfg=config, mode='test')
    eval_loader = DatasetForLR(config, mode="test", need_sample=False)
    trainer.train_one_epoch(online_model, train_loader)
    trainer.valid_one_epoch(online_model, eval_loader)


# Script entry point: run the training routine only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()